"""
extract faces from a list of images which contains faces
I am using this to extract 2 man faces as dataset

usage:

# extract all image faces out, this will save all faces image into a new directory
python3 extract_faces.py --images_dir ./trump

# extract all faces of a man and save into a new directory
python3 extract_faces.py --video /path/to/video_file.mp4
"""
import sys
import dlib
from skimage import io
import os
import argparse
import cv2
import numpy as np


def parse_args():
    """Parse the command line of the face extractor.

    Options:
        --images_dir / -i: directory of input images
            (default: ``./data/gal_gadot``).
        --video / -v: path to a video file (default: ``None``).

    Returns:
        argparse.Namespace with ``images_dir`` and ``video`` attributes.
    """
    cli = argparse.ArgumentParser(prog='----------- face extractor ----------')
    # (option strings, keyword settings) for each supported option; the long
    # form comes first so argparse derives the attribute name from it.
    option_specs = (
        (('--images_dir', '-i'), {'default': './data/gal_gadot'}),
        (('--video', '-v'), {}),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    namespace = cli.parse_args()
    return namespace


# Shared dlib frontal face detector, created lazily on first use so that
# processing a directory does not rebuild it for every image.
_FACE_DETECTOR = None


def _get_face_detector():
    """Return the shared dlib frontal face detector, creating it on first use."""
    global _FACE_DETECTOR
    if _FACE_DETECTOR is None:
        _FACE_DETECTOR = dlib.get_frontal_face_detector()
    return _FACE_DETECTOR


def _face_box(d, img_w, img_h):
    """Clip a dlib rectangle to the image bounds and return (x, y, w, h).

    dlib detections may extend past the image border (negative left/top or
    right/bottom beyond the size); slicing with such values yields a wrong
    or empty crop. A box entirely outside the image gets w == 0 or h == 0.
    """
    left = int(d.left())
    top = int(d.top())
    # Exclusive right/bottom edges, matching the original y:y+h / x:x+w crop.
    right = min(img_w, left + int(d.width()))
    bottom = min(img_h, top + int(d.height()))
    left = max(0, left)
    top = max(0, top)
    return left, top, max(0, right - left), max(0, bottom - top)


def detect_faces_on_single_image(img_f, is_show=False):
    """Detect faces in one image file and save each face crop as a JPEG.

    Crops are written to ``<dirname(img_f)>_faces/<stem>_<i>.jpg``, where
    ``<stem>`` is the file name without its extension and ``i`` is the
    detection index. Files OpenCV cannot read, or on which detection fails,
    are reported and skipped so a batch run keeps going.

    Args:
        img_f: path to the image file.
        is_show: if True, show the annotated image and each crop in OpenCV
            windows and wait for a key press after every face.

    Returns:
        List of paths of the face crops that were written (empty when the
        file could not be read or no face was found).
    """
    print("-- Processing file: {}".format(img_f))
    stem = os.path.splitext(os.path.basename(img_f))[0]
    save_p = os.path.join(os.path.dirname(img_f) + '_faces', stem)
    os.makedirs(os.path.dirname(save_p), exist_ok=True)

    saved = []
    # IMREAD_COLOR guarantees a 3-channel BGR uint8 array. (cv2.COLOR_* values
    # are colour-conversion codes for cvtColor, not imread flags.)
    image = cv2.imread(img_f, cv2.IMREAD_COLOR)
    if image is None:
        print('     we ignore this image not valid: {}'.format(img_f))
        return saved

    try:
        # Feed the detector an RGB copy (the conversion the COLOR_BGR2RGB
        # argument was presumably meant to do); `image` stays BGR because
        # cv2.imwrite / cv2.imshow expect BGR.
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        detections = _get_face_detector()(rgb, 1)
    except Exception as e:  # best-effort batch tool: report and move on
        print(e)
        print('     we ignore this image not valid: {}'.format(img_f))
        return saved

    print("     Number of faces detected: {}".format(len(detections)))
    img_h, img_w = image.shape[:2]
    # Rectangles are drawn on a separate copy so they never end up inside
    # the crops of later (possibly overlapping) faces.
    display = image.copy() if is_show else None
    for i, d in enumerate(detections):
        print("     Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
            i, d.left(), d.top(), d.right(), d.bottom()))
        x, y, w, h = _face_box(d, img_w, img_h)
        if w == 0 or h == 0:
            print('     skip detection {}: box lies outside the image'.format(i))
            continue

        face_patch = image[y:y + h, x:x + w].copy()
        save_p_i = '{}_{}.jpg'.format(save_p, i)
        if cv2.imwrite(save_p_i, face_patch):
            print('     saved this face patch into {}'.format(save_p_i))
            saved.append(save_p_i)
        else:
            print('     failed to save face patch into {}'.format(save_p_i))

        if is_show:
            cv2.rectangle(display, (x, y), (x + w, y + h), (0, 0, 255), 2)
            cv2.imshow('image', display)
            cv2.imshow('face patch', face_patch)
            cv2.waitKey(0)
    return saved


def extract_faces_from_images(images_dir):
    """Run face extraction on every regular file directly inside *images_dir*.

    Entries are processed in sorted name order so runs are reproducible, and
    sub-directories are skipped rather than handed to the image reader.
    Each file is passed to ``detect_faces_on_single_image``, which writes
    the face crops.

    Args:
        images_dir: directory containing the input images.
    """
    for name in sorted(os.listdir(images_dir)):
        img_f = os.path.join(images_dir, name)
        if not os.path.isfile(img_f):
            continue
        detect_faces_on_single_image(img_f)


def extract_faces_from_video(video_f=None):
    """Extract face crops from the frames of a video file.

    Not implemented yet. Raises instead of silently returning so a caller
    cannot mistake "nothing was extracted" for success.

    Args:
        video_f: path to the video file (reserved for the implementation).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        'face extraction from video is not implemented yet: {}'.format(video_f))


if __name__ == '__main__':
    args = parse_args()
    if args.video:
        # Video extraction is not implemented: fail loudly (message on
        # stderr, exit status 1) instead of exiting successfully having
        # done nothing.
        sys.exit('error: face extraction from video is not implemented yet: {}'.format(
            args.video))
    extract_faces_from_images(args.images_dir)


